"""
FEniCS tutorial demo program: Poisson equation with Dirichlet conditions.
Test problem is chosen to give an exact solution at all nodes of the mesh.

  -Laplace(u) = f    in the unit square
            u = u_D  on the boundary

  u_D = 1 + x^2 + 2y^2
    f = -6


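ref: https://www.iesensor.com/blog/2017/05/24/gmsh_fenics_meshing/
     https://comphysblog.wordpress.com/2018/08/15/fenics-2d-electrostatics-with-imported-mesh-and-boundaries/

NOTE: the description above is inherited from the FEniCS Poisson tutorial.
The PlasmaHModel class below instead solves a steady, field-aligned
anisotropic diffusion problem on a mesh imported from Gmsh (see the refs
for the mesh/facet-region conversion workflow).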
"""

from __future__ import print_function
from fenics import *
from dolfin import *
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
#from vtkplotter.dolfin import plot
import numpy as np
import math

class PlasmaHModel:
    """Steady field-aligned anisotropic diffusion on an imported 2D mesh.

    Solves, with zero source and Dirichlet data on tagged facets,

        -div(K grad n) = 0,   K = Dr * (I - b b^T) + Dp * b b^T,

    where ``b`` is the unit direction of the field returned by
    ``efit.get_B_rz`` at each P1 degree of freedom. Diffusion along ``b``
    therefore uses ``Dp`` and diffusion perpendicular to ``b`` uses ``Dr``.
    """

    # Diffusion coefficients along (Dp) and across (Dr) the field direction.
    # Class attributes so an instance can override them before run().
    Dp = 10000.0
    Dr = 1.0

    # Facet tags as written in <fname>_facet_region.xml.
    INNER_EDGE_TAG = 384
    DIVERTOR_LL_TAG = 386
    DIVERTOR_LR_TAG = 387
    # Tags 385 (first wall, 1.0e16) and 388 (dome, 10.0) exist in the
    # mesh but their boundary conditions are currently disabled.

    def __init__(self, efit):
        # efit: object exposing get_B_rz(point) -> indexable (B0, B1);
        # presumably an EFIT equilibrium wrapper -- TODO confirm.
        self.efit = efit

    def run(self, fname="hl2a", show=True):
        """Build the problem, solve it and plot the solution.

        Parameters
        ----------
        fname : str
            Basename of the DOLFIN XML files: ``<fname>.xml`` (mesh) and
            ``<fname>_facet_region.xml`` (facet tags). Defaults to "hl2a".
        show : bool
            If True (default), call ``plt.show()`` after plotting.

        Returns
        -------
        dolfin.Function
            The computed solution ``n`` on the P1 function space.

        Raises
        ------
        FileNotFoundError
            If the facet-region file is missing; the boundary conditions
            cannot be built without it.
        ValueError
            If the mesh geometry is not two-dimensional.
        """
        mesh, boundaries = self._load_mesh(fname)
        V = FunctionSpace(mesh, 'P', 1)
        b = self._field_direction(mesh, V)
        bcs = self._boundary_conditions(V, boundaries)
        n = self._solve(V, b, bcs)
        self._plot(n, show)
        return n

    @staticmethod
    def _require_file(path):
        """Return *path* if it exists, else raise FileNotFoundError."""
        if not os.path.exists(path):
            raise FileNotFoundError("required mesh file not found: %s" % path)
        return path

    @staticmethod
    def _unit_vector(br, bz):
        """Return (br, bz) scaled to unit length, or (0.0, 0.0) if it is zero.

        A zero direction removes the anisotropic part of K at that dof,
        leaving isotropic diffusion with coefficient Dr instead of a
        division by zero (or inf/nan for numpy scalars).
        """
        norm = math.hypot(br, bz)
        if norm == 0.0:
            return 0.0, 0.0
        return br / norm, bz / norm

    def _load_mesh(self, fname):
        """Read the mesh and its facet markers; fail early if markers are missing."""
        facet_file = self._require_file(fname + "_facet_region.xml")
        mesh = Mesh(fname + ".xml")
        boundaries = MeshFunction("size_t", mesh, facet_file)
        print("xml mesh reading done")
        return mesh, boundaries

    def _field_direction(self, mesh, V):
        """Sample efit at every dof of V and return the unit direction b in a P1 vector space."""
        gdim = mesh.geometry().dim()
        if gdim != 2:
            # b is built from exactly two components below.
            raise ValueError("PlasmaHModel expects a 2D mesh, got gdim=%d" % gdim)

        VV = VectorFunctionSpace(mesh, "P", 1)
        # Coordinates as len(dofs) x gdim array.
        dofs_x = V.tabulate_dof_coordinates().reshape((-1, gdim))
        print("shape of dofs_x: ", dofs_x.shape)

        b = Function(VV)
        b0 = Function(V)
        b1 = Function(V)
        b0_array = b0.vector().get_local()
        b1_array = b1.vector().get_local()
        print("shape of b0_array: ", b0_array.shape)

        for i in range(b0_array.shape[0]):
            b_temp = self.efit.get_B_rz(np.array([dofs_x[i, 0], dofs_x[i, 1]]))
            b0_array[i], b1_array[i] = self._unit_vector(b_temp[0], b_temp[1])

        b0.vector().set_local(b0_array)
        b0.vector().apply("insert")  # finalise the local write (needed in parallel)
        b1.vector().set_local(b1_array)
        b1.vector().apply("insert")
        assign(b, [b0, b1])
        return b

    def _boundary_conditions(self, V, boundaries):
        """Dirichlet conditions on the tagged facets (values as in the original setup)."""
        inner_edge_boundary = DirichletBC(V, Constant(1.0e18), boundaries, self.INNER_EDGE_TAG)
        divertor_ll_boundary = DirichletBC(V, Constant(1.0e16), boundaries, self.DIVERTOR_LL_TAG)
        divertor_lr_boundary = DirichletBC(V, Constant(1.0e16), boundaries, self.DIVERTOR_LR_TAG)
        return [inner_edge_boundary, divertor_ll_boundary, divertor_lr_boundary]

    def _solve(self, V, b, bcs):
        """Assemble the anisotropic diffusion form and solve it."""
        u = TrialFunction(V)
        v = TestFunction(V)
        f = Constant(0.0)
        # Parallel part: Dp * (b b^T) grad u . grad v
        ap = self.Dp * inner(b * dot(b, nabla_grad(u)), grad(v)) * dx
        # Perpendicular part: Dr * (I - b b^T) grad u . grad v
        ar = (self.Dr * inner(grad(u), grad(v)) * dx
              - self.Dr * inner(b * dot(b, nabla_grad(u)), grad(v)) * dx)
        a = ar + ap
        L = f * v * dx

        n = Function(V)
        solve(a == L, n, bcs)
        return n

    @staticmethod
    def _plot(n, show=True):
        """Colour plot of the solution with a colour bar; optionally show it."""
        c = plot(n, mode='color', cmap=mpl.cm.Spectral_r)
        plt.colorbar(c)
        if show:
            plt.show()
